open Interface
open Instantiator

(* Common interface to the interpreters extracted from Velus: a program
   representation, an interpreter state, and reset/step entry points. *)
module type Interpreter = sig
  (** Program representation the interpreter runs on (e.g. an Obc program). *)
  type syn

  (** Interpreter state, threaded from one [step] to the next. *)
  type state

  (** [reset prog name] builds the initial state of the class/node called
      [name] in the program [prog]. *)
  val reset : syn -> string -> state

  (** [step prog name st ins] runs one step of [name] from state [st] on the
      stream values [ins]; returns the output stream values and the new
      state. *)
  val step : syn -> string -> state -> OpAux.svalue list -> (OpAux.svalue list * state)
end

(** Raised when the extracted interpreter reports an error. *)
exception InterpreterError of string

(** Raised when a trace token cannot be decoded; carries the offending text. *)
exception ParseInputError of string

(* Human-readable rendering, used by [Printexc.to_string] and for uncaught
   exceptions. *)
let () =
  let describe = function
    | InterpreterError msg -> Some ("Interpreter error: " ^ msg)
    | ParseInputError input -> Some ("Couldn't parse input " ^ input)
    | _ -> None
  in
  Printexc.register_printer describe

(** [interpreter_error msg] renders the Velus error message [msg] with
    [Driveraux.print_error] and raises it.
    @raise InterpreterError always, carrying the rendered message. *)
let interpreter_error (msg : Errors.errcode list) =
  (* Render into a fresh buffer ([asprintf]) rather than the shared global
     [Format.str_formatter], so unrelated pending output cannot leak into the
     message (nor be swallowed by it). *)
  raise (InterpreterError (Format.asprintf "%a" Driveraux.print_error msg))

(** Convert an OCaml string to a Velus identifier (via a Coq string). *)
let ident_of_str s = Ident.str_to_pos (Camlcoq.coqstring_of_camlstring s)

(** Convert a Velus identifier back to an OCaml string. *)
let str_of_ident x = Camlcoq.camlstring_of_coqstring (Ident.pos_to_str x)

(* Interpreting Obc *)
(** Interpreter for Obc programs, built on [Obc.Int.eval_method]. Classes are
    designated by name; [reset] and [step] call the corresponding methods. *)
module ObcInterpreter : Interpreter with type syn = Obc.Syn.program = struct
  type syn = Obc.Syn.program
  type state = Obc.Sem.menv

  (* Stream values are present/absent; the Obc evaluator works on options. *)
  let to_option = function
    | OpAux.Coq_present v -> Some v
    | OpAux.Coq_absent -> None

  let of_option = function
    | Some v -> OpAux.Coq_present v
    | None -> OpAux.Coq_absent

  (* Evaluate method [meth] of class [cname] in memory [me] on [args];
     evaluation errors are turned into [InterpreterError]. *)
  let call p me cname meth args =
    match Obc.Int.eval_method p me (ident_of_str cname) meth args with
    | OK result -> result
    | Error msg -> interpreter_error msg

  let reset p cname =
    fst (call p VelusMemory.empty_memory cname Ident.Ids.reset [])

  let step p cname me ins =
    let me', outs = call p me cname Ident.Ids.step (List.map to_option ins) in
    (List.map of_option outs, me')
end

(* let step_fun = ref (fun () -> ()) *)

(** [choose_default names old] returns [old] when it occurs in [names], and
    the last element of [names] otherwise (including when [old] is [None]).
    @raise Invalid_argument if [names] is empty. *)
let rec choose_default names old =
  match names, old with
  | [], _ -> invalid_arg "choose_default"
  | [n], _ -> n
  | hd :: _, Some o when o = hd -> hd
  | _ :: tl, _ -> choose_default tl old

(* Option binding operator for unit-returning continuations: runs [f] on the
   payload of [Some], does nothing on [None]. Note that it is shadowed by the
   [Errors.res] binding operator defined further down in this file. *)
let ( let* ) x f = Option.iter f x

(** [find_constr_num id constrs] is the 0-based position of the first
    occurrence of [id] in [constrs].
    @raise Invalid_argument if [id] does not occur in [constrs]. *)
let find_constr_num id constrs =
  let rec go idx = function
    | [] -> invalid_arg "find_constr_num"
    | c :: rest -> if c = id then idx else go (idx + 1) rest
  in
  go 0 constrs

(** [parse_input typ s] decodes one token [s] of a trace line as a stream
    value of type [typ]:
    - ["."] is an absent value;
    - booleans: ["T"] is true, any other token reads as false;
    - other enumerations are written with the constructor name;
    - integers, longs and floats use OCaml literal syntax
      ([Int32.of_string], [Int64.of_string], [float_of_string]).
    Single-precision floats are stored as [Vsingle]. Double-precision floats
    are stored as [Vfloat] at full precision; previously they went through
    float32 and lost precision.
    @raise ParseInputError if [s] cannot be decoded at type [typ]. *)
let parse_input typ s =
  let open Interface in
  try
    if s = "." then
      OpAux.Coq_absent
    else
      OpAux.Coq_present
        (match typ with
         | Op.Tenum (id, _) when id = Ident.Ids.bool_id ->
           Op.Venum (Camlcoq.Nat.of_int (if s = "T" then 1 else 0))
         | Op.Tenum (_, constrs) ->
           Op.Venum (Camlcoq.Nat.of_int (find_constr_num (ident_of_str s) constrs))
         | Op.Tprimitive (Tint _) ->
           Op.Vscalar (Values.Vint (Camlcoq.coqint_of_camlint @@ Int32.of_string @@ s))
         | Op.Tprimitive (Op.Tlong _) ->
           Op.Vscalar (Values.Vlong (Camlcoq.coqint_of_camlint64 @@ Int64.of_string @@ s))
         | Op.Tprimitive (Op.Tfloat F32) ->
           (* float32 values are tagged [Vsingle] in CompCert's value model. *)
           Op.Vscalar (Values.Vsingle (Camlcoq.coqfloat32_of_camlfloat @@ float_of_string @@ s))
         | Op.Tprimitive (Op.Tfloat F64) ->
           Op.Vscalar (Values.Vfloat (Camlcoq.coqfloat_of_camlfloat @@ float_of_string @@ s)))
  with _ -> raise (ParseInputError s)

(** [parse_inputs typs line] decodes one trace line: whitespace-separated
    tokens, one per element of [typs], each decoded with [parse_input].
    - Tabs count as spaces.
    - Leading and trailing whitespace is ignored, including the ['\r'] left
      by CRLF line endings.
    - Runs of blanks are ignored, so an empty line is a valid step for a node
      without inputs.
    @raise ParseInputError if a token is invalid, or if the number of tokens
    differs from the number of types. *)
let parse_inputs typs s =
  let tokens =
    s
    |> String.map (fun c -> if c = '\t' then ' ' else c)
    |> String.trim
    |> String.split_on_char ' '
    |> List.filter (fun tok -> tok <> "")
  in
  if List.compare_lengths typs tokens <> 0 then
    raise
      (ParseInputError
         (Printf.sprintf "%S (expected %d values, got %d)" s
            (List.length typs) (List.length tokens)))
  else List.map2 parse_input typs tokens

(** [string_of_val v] renders a scalar CompCert value: integers and longs in
    signed decimal, floats with [Float.to_string], [Vundef] as
    ["undefined"]. Single-precision floats ([Vsingle]) produced by float32
    streams are now supported instead of crashing.
    @raise Invalid_argument for any other value (e.g. pointers). *)
let string_of_val = function
  | Values.Vundef -> "undefined"
  | Values.Vint i -> Int32.to_string (Camlcoq.camlint_of_coqint i)
  | Values.Vlong l -> Int64.to_string (Camlcoq.camlint64_of_coqint l)
  | Values.Vfloat f -> Float.to_string (Camlcoq.camlfloat_of_coqfloat f)
  | Values.Vsingle f -> Float.to_string (Camlcoq.camlfloat_of_coqfloat32 f)
  | _ -> invalid_arg "string_of_val"

(** [string_of_value typ v] renders [v] for the TeX chronogram. Scalars are
    shown with [string_of_val]. Booleans are shown as [\true{}]/[\false{}].
    Other enumeration values are shown by their constructor name.
    @raise Invalid_argument if an enumeration value is paired with a
    non-enumeration type. *)
let string_of_value typ v =
  match v, typ with
  | Op.Vscalar scalar, _ -> string_of_val scalar
  | Op.Venum e, Op.Tenum (tx, tconstrs) ->
    let n = Camlcoq.Nat.to_int e in
    if tx = Ident.Ids.bool_id then (if n = 1 then "\\true{}" else "\\false{}")
    else str_of_ident (List.nth tconstrs n)
  | Op.Venum _, _ -> invalid_arg "string_of_value"

(** [string_of_svalue typ sv] renders a stream value; an absent value is a
    single space (an empty chronogram cell). *)
let string_of_svalue typ sv =
  match sv with
  | OpAux.Coq_present v -> string_of_value typ v
  | OpAux.Coq_absent -> " "

(** Raised when the front end rejects the source program. The diagnostic (a
    Velus error message or a syntax error location) is printed before this
    is raised. *)
exception CompileError

(** [read_file filename] returns the whole contents of [filename].
    The file is opened in binary mode, so that [in_channel_length] (a byte
    count) matches the number of characters read on every platform. The
    channel is closed even if reading fails.
    @raise Sys_error if the file cannot be opened. *)
let read_file filename =
  let ch = open_in_bin filename in
  Fun.protect
    ~finally:(fun () -> close_in_noerr ch)
    (fun () -> really_input_string ch (in_channel_length ch))

(** Binding operator for Velus' [Errors.res]. On [OK v] it continues with
    [f v]. On [Error], it prints the message on standard output and raises
    [CompileError]. *)
let ( let* ) r f =
  match r with
  | Errors.OK v -> f v
  | Errors.Error errmsg ->
    (* "@." ends the line and flushes, so the message is emitted before the
       exception unwinds, even if a caller catches it and carries on. *)
    Format.printf "%a@." Driveraux.print_error errmsg;
    raise CompileError

open LustreParser.MenhirLibParser.Inter

(** [tokens_stream s] lexes [s] with [LustreLexer.initial] into the lazy
    token buffer consumed by the verified parser. Each token is produced only
    when its cell of the buffer is forced. *)
let tokens_stream s : buffer =
  let lexbuf = Lexing.from_string s in
  let rec next () = Buf_cons (LustreLexer.initial lexbuf, lazy (next ())) in
  lazy (next ())

module I = LustreParser2.MenhirInterpreter

(** [parsing_loop editor toks checkpoint] drives the incremental
    [LustreParser2] parser from [checkpoint] over the tokens of [toks] until
    it reports a syntax error. It then prints ["syntax error at <loc>"], on
    its own line and flushed, on standard output.
    It is only run on input that the verified parser has already rejected,
    so acceptance is treated as impossible. [editor] is passed along
    unchanged and is otherwise unused. *)
let rec parsing_loop editor toks (checkpoint : unit I.checkpoint) =
  match checkpoint with
  | I.InputNeeded _ ->
    (* The parser needs a token: take the next one from the stream, offer it,
       and continue with the rest of the stream. *)
    let (token, loc) = Relexer.map_token (LustreParser.MenhirLibParser.Inter.buf_head toks) in
    let loc = LustreLexer.lexing_loc loc in
    let checkpoint = I.offer checkpoint (token, loc, loc) in
    parsing_loop editor (LustreParser.MenhirLibParser.Inter.buf_tail toks) checkpoint
  | I.Shifting _
  | I.AboutToReduce _ ->
    parsing_loop editor toks (I.resume checkpoint)
  | I.HandlingError _ ->
    (* Syntax error: report where the offending token is and stop. The
       newline and flush ([%!]) keep the message on its own line and make it
       visible before whatever is printed next. *)
    let (_, loc) = Relexer.map_token (LustreParser.MenhirLibParser.Inter.buf_head toks) in
    Printf.printf "syntax error at %s\n%!" (LustreLexer.string_of_loc loc)
  | I.Accepted _ ->
    assert false (* LustreParser2 should not succeed where Parser failed. *)
  | I.Rejected ->
    (* Cannot happen: we stop at the first [HandlingError]. *)
    assert false

(** [reparse filename toks] runs the incremental parser on [toks], starting
    at the location of the first token, to report the syntax error location. *)
let reparse filename toks =
  let (_, first_loc) = Relexer.map_token (buf_head toks) in
  let start =
    LustreParser2.Incremental.translation_unit_file (LustreLexer.lexing_loc first_loc)
  in
  parsing_loop filename toks start

(** [parse_with_error filename toks] parses [toks] with the verified parser
    and returns the AST. If the parser fails, it re-parses with the
    incremental parser to report the error location, then raises
    [CompileError]. *)
let parse_with_error filename toks =
  Diagnostics.reset ();
  let fuel = Camlcoq.Nat.of_int 1000 in
  match LustreParser.translation_unit_file fuel toks with
  | Parsed_pr (ast, _) -> ast
  | Fail_pr_full _ ->
    reparse filename toks;
    raise CompileError
  | Timeout_pr -> assert false

(** Read, lex and parse the Lustre source file [filename]. *)
let lex_and_parse_program filename =
  filename |> read_file |> tokens_stream |> parse_with_error filename

(** [compile_to_obc filename] compiles the Lustre source [filename] to an Obc
    program for interpretation. The stages are:
    - parse;
    - [LustreElab.elab_declarations];
    - [Velus.l_to_nl];
    - [dce_global] and [remove_dup_regs];
    - [NL2Stc.translate];
    - [Velus.schedule_program];
    - [Stc2Obc.translate], [fuse_program] and [normalize_switches].
    @raise CompileError if parsing or any checked stage fails. *)
let compile_to_obc filename =
  let open Interface in
  let open Instantiator in
  let decls = lex_and_parse_program filename in
  let* elaborated = LustreElab.elab_declarations decls in
  let* nlustre = Velus.l_to_nl elaborated in
  let stc =
    NL2Stc.translate (NL.DRR.DRR.remove_dup_regs (NL.DCE.DCE.dce_global nlustre))
  in
  let* scheduled = Velus.schedule_program stc in
  Obc.SwN.normalize_switches (Obc.Fus.fuse_program (Stc2Obc.translate scheduled))

(** Trace file given with [-trace], if any. *)
let trace_file = ref None

(** TeX output file given with [-o], if any. *)
let output_file = ref None

(* [filename] with its extension replaced by [ext]. [Filename.remove_extension]
   never raises, unlike [chop_extension], which fails on names without an
   extension. *)
let default_file_name ~ext filename = Filename.remove_extension filename ^ ext

(** Trace file to read: the [-trace] argument, or [filename] with extension
    [.trace]. The default is only computed when it is needed. *)
let get_trace_file filename =
  match !trace_file with
  | Some f -> f
  | None -> default_file_name ~ext:".trace" filename

(** TeX file to write: the [-o] argument, or [filename] with extension
    [.tex]. *)
let get_output_file filename =
  match !output_file with
  | Some f -> f
  | None -> default_file_name ~ext:".tex" filename

(** [repeat_print n fmt s] prints [s] on [fmt] [n] times. It prints nothing
    when [n <= 0]; previously a negative [n] never terminated. *)
let repeat_print n fmt s =
  for _i = 1 to n do
    Format.pp_print_string fmt s
  done

(** [print_sep_list sep fprint p vals] prints [vals] on [p] with [fprint],
    printing the format [sep] between consecutive elements. *)
let print_sep_list sep fprint p vals =
  let pp_sep ppf () = Format.fprintf ppf sep in
  Format.pp_print_list ~pp_sep fprint p vals

(** [print_line fmt ((name, ty), history)] prints one chronogram row:
    [name & v1 & ... & vn & \ldots]. [history] holds the latest value first,
    so it is reversed before printing. *)
let print_line fmt ((name, ty), history) =
  let print_cell ppf sv = Format.pp_print_string ppf (string_of_svalue ty sv) in
  Format.fprintf fmt "%s & %a & \\ldots"
    name (print_sep_list " & " print_cell) (List.rev history)

(* [run_trace trace_name ic parse step ins outs state] reads [ic] line by line
   until end of file. Each line is decoded by [parse] and fed to [step],
   starting from [state].
   It returns three things: the number of lines processed, and the
   per-variable input and output histories. In each history the latest value
   comes first; [ins] and [outs] are the initial (empty) histories.
   Any exception raised while handling a line is re-raised once as [Failure],
   prefixed with [trace_name] and the 1-based line number. *)
let run_trace trace_name ic parse step ins outs state =
  (* The handler covers only the current line, not the recursive call. So
     [loop] is tail-recursive, and an error is not re-wrapped once per line
     already read. *)
  let rec loop nsteps ins outs state =
    match input_line ic with
    | exception End_of_file -> (nsteps, ins, outs)
    | line ->
      let ins, outs, state =
        try
          let step_ins = parse line in
          let step_outs, state = step state step_ins in
          (List.map2 List.cons step_ins ins, List.map2 List.cons step_outs outs, state)
        with e ->
          failwith
            (Printf.sprintf "%s, line %d: %s" trace_name (nsteps + 1) (Printexc.to_string e))
      in
      loop (nsteps + 1) ins outs state
  in
  loop 0 ins outs state

(* Print the TeX macro [\<clsname>chrono] taking [nb_output] arguments (the
   output rows). It is a tabular with one column per step plus a trailing
   [\ldots] column; underscores in the class name become 'X'. *)
let print_chronogram clsname nb_output nb_steps fmt (in_rows, out_rows) =
  Format.fprintf fmt "\\newcommand{\\%schrono}[%d]{\n"
    (String.map (fun c -> if c = '_' then 'X' else c) clsname)
    nb_output;
  Format.fprintf fmt "\\begin{tabular}{l|%a}\n" (repeat_print (nb_steps + 1)) "c";
  print_sep_list "\\\\\n" print_line fmt in_rows;
  Format.fprintf fmt "\\\\\n\\hline\n";
  print_sep_list "\\\\\n" print_line fmt out_rows;
  Format.fprintf fmt "\\\\\n\\end{tabular}}\n"

(** [main filename] compiles [filename] to Obc and interprets the step
    method of the first class on the trace file (see [get_trace_file]). It
    then writes the resulting chronogram as a TeX macro to the output file
    (see [get_output_file]). Both files are closed even on error. Nodes
    without inputs are supported; the step count is the number of trace
    lines. *)
let main filename =
  let p = compile_to_obc filename in
  let cls = List.hd p.classes in
  let clsname = str_of_ident cls.c_name in
  let stepme = Option.get (Obc.Syn.find_method Ident.Ids.step cls.c_methods) in
  let intys = List.map snd stepme.m_in in
  let trace_name = get_trace_file filename in
  let trace_in = open_in trace_name in
  let nb_steps, ins, outs =
    Fun.protect
      ~finally:(fun () -> close_in_noerr trace_in)
      (fun () ->
         run_trace trace_name trace_in (parse_inputs intys)
           (ObcInterpreter.step p clsname)
           (List.map (fun _ -> []) stepme.m_in)
           (List.map (fun _ -> []) stepme.m_out)
           (ObcInterpreter.reset p clsname))
  in
  (* Inputs are labelled with their Lustre name, outputs with the macro
     parameters #1..#n. *)
  let in_rows =
    List.map2
      (fun (x, ty) values -> ((Printf.sprintf "\\lus{%s}" (str_of_ident x), ty), values))
      stepme.m_in ins
  in
  let out_rows =
    List.combine
      (List.mapi (fun i (_, ty) -> (Printf.sprintf "#%d" (i + 1), ty)) stepme.m_out)
      outs
  in
  let tex =
    Format.asprintf "%a"
      (print_chronogram clsname (List.length stepme.m_out) nb_steps)
      (in_rows, out_rows)
  in
  let tex_out = open_out (get_output_file filename) in
  Fun.protect
    ~finally:(fun () -> close_out_noerr tex_out)
    (fun () ->
       output_string tex_out tex;
       (* Close explicitly so flush errors are reported; closing again in
          [finally] is a no-op. *)
       close_out tex_out)

(** [-trace] handler: use [name] as the trace file. *)
let set_trace_file = fun name -> trace_file := Option.some name

(** [-o] handler: use [name] as the TeX output file. *)
let set_output_file = fun name -> output_file := Option.some name

(** Command-line options. The documentation strings start with a space,
    which is the [Arg.align] convention, so the list goes through
    [Arg.align] to line them up in the [-help] output. *)
let spec = Arg.align [
  "-trace", Arg.String set_trace_file, " Set <trace> file name";
  "-o", Arg.String set_output_file, " Set <output> file name";
]

(** Usage header for [-help], including the configured target. *)
let usage_msg =
  let arch, system, abi = Configuration.(arch, system, abi) in
  Format.sprintf "Usage: velustotex [options] <source>\n(arch=%s system=%s abi=%s)\n"
    arch system abi

(* Select CompCert's x86-64 machine configuration before the command line is
   processed. [let () =] (not [let _ =]) checks that these effects return
   unit. *)
let () = Machine.config := Machine.x86_64

(* Each anonymous command-line argument is a source file handed to [main]. *)
let () =
  Arg.parse spec main usage_msg
